import json
import os
import sys
import getopt

# Remember the directory the script was launched from, then make the script's
# own directory the working directory.  Because of this chdir, every relative
# path used afterwards (including the -f/-o command-line values) is resolved
# against the script's directory, not the caller's.
cwd = os.getcwd()
abs_path = os.path.abspath(os.path.dirname(__file__))
if abs_path != cwd:
    os.chdir(abs_path)

# Type names that map directly onto a built-in plato variable class; every
# other non-container type name is treated as a generated struct.
basic_type = ["float", "uint32", "bool", "int32", "string", "int64", "uint64"]


class NodeGenerator:
    """Generate plato C++ sources from a JSON description file.

    For every node under ``categories[*].nodes`` a ``<title>.cpp`` file is
    written, and for every entry under ``structs`` a ``<title>.hh`` /
    ``<title>.cpp`` pair is written.  All files go to ``out_dir``.

    Constructing an instance runs the whole generation immediately.
    """

    # Pin/field "type" values that describe a container; the element types
    # are then read from the "key" (and, for dict, "value") entries.
    _CONTAINER_TYPES = ('seq', 'set', 'dict')

    def __init__(self, file_path, out_dir):
        """Store the paths and generate all node and struct files.

        Args:
            file_path: path of the JSON description to read.
            out_dir: directory the generated files are written into.
        """
        self.file_path = file_path
        self.out_dir = out_dir
        self.gen_plato_base_cpp()
        self.gen_structs()

    def _load_description(self):
        """Parse and return the JSON document at ``self.file_path``."""
        with open(self.file_path) as src:
            return json.load(src)

    def gen_plato_base_cpp(self):
        """Write one node ``.cpp`` file per node of every category."""
        cont = self._load_description()
        for category in cont["categories"]:
            for node in category["nodes"]:
                self.gen_node_cpp(node)

    def gen_node_cpp(self, node):
        """Write ``<out_dir>/<title>.cpp`` defining the creator for ``node``.

        Args:
            node: dict with "title", "id", "inputs" and "outputs"; the
                optional keys "isServer" / "isClient" select the sync type.
        """
        node_name = node["title"]
        # Touch "id" before creating the file so that a node without an id
        # fails without leaving an empty output file behind.
        node["id"]
        # The context manager guarantees the generated text is flushed and
        # the handle released even if emission raises part-way through.
        with open(os.path.join(self.out_dir, node_name) + ".cpp", "w") as file:
            self._emit_node_cpp(file, node)

    def _emit_node_cpp(self, file, node):
        """Write the complete C++ text of one node creator to ``file``."""
        node_name = node["title"]
        static_id = str(node["id"])
        write = file.write

        write("#pragma once\n\n")
        write('#include "plato_node.hh"\n')
        write('#include "plato_node_creator.hh"\n')
        # NOTE(review): container pins ('seq'/'set'/'dict') are not basic
        # types, so they also produce an include such as "seq.hh" here --
        # confirm whether that is intended.
        for pins in (node['inputs'], node['outputs']):
            for pin in pins:
                if pin["type"] == 'exec':
                    continue
                if not self.is_basic_type(pin["type"]):
                    write('#include "' + pin["type"] + '.hh"\n')
        write('\n')
        write('extern auto add_node_creator(const char *name, plato::PlatoNodeCreator *creator) -> bool;\n\n')
        write('namespace plato { namespace node {\n\n')
        write('struct ' + node_name + ' : public PlatoNodeCreator {\n')
        write('  using PinNameIndexMap = std::unordered_map<std::string, PlatoPinIndex>;\n')
        write('  PinNameIndexMap pin_name_input_index_map_;\n')
        write('  PinNameIndexMap pin_name_output_index_map_;\n')
        # NOTE(review): only the presence of the key is tested, so
        # {"isServer": false} still selects SERVER_SIDE -- confirm schema.
        if "isServer" in node:
            sync_side = 'SERVER_SIDE'
        elif "isClient" in node:
            sync_side = 'CLIENT_SIDE'
        else:
            sync_side = 'NONE'
        write('  PlatoNodeSyncType sync_type_{PlatoNodeSyncType::' + sync_side + '};\n')
        write('  std::size_t var_mem_size_{0};\n')

        # Constructor: register every pin title with its positional index.
        write('  ' + node_name + '() {\n')
        for index, pin in enumerate(node['inputs']):
            write('    pin_name_input_index_map_.emplace("' + pin['title'] + '",' + str(index) + ');\n')
        for index, pin in enumerate(node['outputs']):
            write('    pin_name_output_index_map_.emplace("' + pin['title'] + '",' + str(index) + ');\n')
        write('  }\n')
        write('  virtual ~' + node_name + '() {}\n')

        # create(): build the node and attach one pin per input/output.
        write('  virtual auto create(DomainPtr domain_ptr, PlatoNodeID id) -> PlatoNodePtr override {\n')
        write('    auto var_sync_type = PlatoVariableSyncType::NONE;\n')
        write('    if (sync_type_ == PlatoNodeSyncType::SERVER_SIDE) {\n')
        write('      var_sync_type = PlatoVariableSyncType::SENDER;\n')
        write('    } else if (sync_type_ == PlatoNodeSyncType::CLIENT_SIDE) {\n')
        write('      var_sync_type = PlatoVariableSyncType::RECEIVER;\n')
        write('    }\n')
        write('    auto node_ptr = new_node(domain_ptr, PlatoNodeID(' + static_id + '), id, sync_type_);\n')
        write('    node_ptr->set_name("' + node_name + '");\n')
        for pin in node['inputs']:
            self.gen_node_pin(file, pin, True)
        for pin in node['outputs']:
            self.gen_node_pin(file, pin, False)
        write('    return node_ptr;\n')
        write('  }\n')

        write('  virtual auto static_id() -> PlatoNodeID override {\n')
        write('    return PlatoNodeID(' + static_id + ');\n')
        write('  }\n')

        # The input and output name->index lookups differ only in the map used.
        for direction in ('input', 'output'):
            map_name = 'pin_name_' + direction + '_index_map_'
            write('  virtual auto get_pin_' + direction + '_index(const std::string& pin_name) -> PlatoPinIndex override {\n')
            write('    auto it = ' + map_name + '.find(pin_name);\n')
            write('    if (it == ' + map_name + '.end()) {\n')
            write('      return INVALID_PIN_INDEX;\n')
            write('    }\n')
            write('    return it->second;\n')
            write('  }\n')

        write('  virtual auto get_node_memory_size() -> std::size_t override {\n')
        for pin in node['inputs']:
            self.gen_node_pin_var_mem_size(file, pin)
        for pin in node['outputs']:
            self.gen_node_pin_var_mem_size(file, pin)
        write('    return var_mem_size_ + get_node_size() + get_pin_size() * (pin_name_input_index_map_.size() + pin_name_output_index_map_.size());\n')
        write('  }\n')
        write('};\n\n')
        write('static ' + node_name + ' creator;\n\n')
        write('static auto res = add_node_creator("' + node_name + '", &creator);\n\n')
        write('}}\n\n')

    def get_var_type(self, str_type):
        """Return the C++ class name used for ``str_type``.

        ``str.capitalize`` upper-cases the first character and lower-cases
        all the others, e.g. ``"uint32"`` -> ``"Uint32"``.
        """
        return str_type.capitalize()

    def is_basic_type(self, str_type):
        """Return True when ``str_type`` is listed in ``basic_type``."""
        return str_type in basic_type

    def _cpp_type(self, spec):
        """Return the C++ value type for a pin/field description.

        Containers become ``Array<K>``, ``Set<K>`` or ``Map<K,V>``; any other
        type is passed through ``get_var_type``.
        """
        kind = spec['type']
        if kind == 'seq':
            return 'Array<' + self.get_var_type(spec['key']) + '>'
        if kind == 'set':
            return 'Set<' + self.get_var_type(spec['key']) + '>'
        if kind == 'dict':
            return 'Map<' + self.get_var_type(spec['key']) + ',' + self.get_var_type(spec['value']) + '>'
        return self.get_var_type(kind)

    def _cpp_ptr_type(self, spec):
        """Return the C++ pointer alias for a struct field description.

        Containers become ``ArrayPtr<K>``, ``SetPtr<K>`` or ``MapPtr<K,V>``;
        any other type becomes ``<Type>Ptr``.
        """
        kind = spec['type']
        if kind == 'seq':
            return 'ArrayPtr<' + self.get_var_type(spec['key']) + '>'
        if kind == 'set':
            return 'SetPtr<' + self.get_var_type(spec['key']) + '>'
        if kind == 'dict':
            return 'MapPtr<' + self.get_var_type(spec['key']) + ',' + self.get_var_type(spec['value']) + '>'
        return self.get_var_type(kind) + 'Ptr'

    def _cpp_size_expr(self, spec):
        """Return the C++ expression giving the memory size of ``spec``.

        Basic and container types use ``sizeof(...)``; every other type is
        asked through its static ``get_fixed_mem_size()``.
        """
        kind = spec['type']
        if kind in self._CONTAINER_TYPES or self.is_basic_type(kind):
            return 'sizeof(' + self._cpp_type(spec) + ')'
        return self._cpp_type(spec) + '::get_fixed_mem_size()'

    def gen_node_pin(self, file, pin, is_input):
        """Write the statements creating one pin and attaching it to the node.

        Args:
            file: writable text stream receiving the C++ code.
            pin: dict with "type" and "title" (plus "key"/"value" for
                container types).
            is_input: True to attach via ``add_input``, False for
                ``add_output``.
        """
        var_type = pin["type"]
        title = pin['title']
        if var_type == 'exec':
            # Execution pins carry no variable.
            file.write('    auto pin_' + title + ' = new_pin(domain_ptr, PlatoPinType::EXEC, nullptr);\n')
        else:
            file.write('    auto var_' + title + ' = domain_ptr->New<' + self._cpp_type(pin) + '>(var_sync_type);\n')
            file.write('    auto pin_' + title + ' = new_pin(domain_ptr, PlatoPinType::VAR, var_' + title + ');\n')
        attach = 'add_input' if is_input else 'add_output'
        file.write('    node_ptr->' + attach + '(pin_' + title + ');\n')

    def gen_node_pin_var_mem_size(self, file, pin):
        """Write the ``var_mem_size_ += ...`` line for a non-exec pin."""
        if pin["type"] == 'exec':
            return
        file.write('    var_mem_size_ += ' + self._cpp_size_expr(pin) + ';\n')

    def gen_structs(self):
        """Write the header and source file of every entry in ``structs``."""
        cont = self._load_description()
        for struct in cont["structs"]:
            self.gen_struct_hh(struct)
            self.gen_struct_cpp(struct)

    def gen_struct_hh(self, struct):
        """Write ``<out_dir>/<title>.hh`` declaring the struct class.

        Args:
            struct: dict with "title" and a "fields" list; each field has
                "title" and "type" (plus "key"/"value" for containers).
        """
        struct_name = struct["title"]
        with open(os.path.join(self.out_dir, struct_name) + '.hh', "w") as file:
            self._emit_struct_hh(file, struct)

    def _emit_struct_hh(self, file, struct):
        """Write the complete header text of one struct to ``file``."""
        struct_name = struct["title"]
        fields = struct["fields"]
        write = file.write

        write('#pragma once\n\n')
        write('#include "plato_variable.hh"\n\n')
        write('namespace plato {\n\n')
        # Forward-declare every non-basic field type together with its Ptr alias.
        # NOTE(review): container type names ('seq'/'set'/'dict') also land
        # here and get declared as classes -- confirm whether intended.
        for field in fields:
            if not self.is_basic_type(field['type']):
                write("class " + field['type'] + ';\n')
                write("using " + field['type'] + 'Ptr = std::shared_ptr<' + field['type'] + '>;\n\n')
        write('class ' + struct_name + ' : public StructVariable {\n')
        write('  ' + struct_name + '() = delete;\n')
        write('  ' + struct_name + '(const ' + struct_name + ' &) = delete;\n')
        write('  ' + struct_name + '(' + struct_name + ' &&) = delete;\n\n')
        write('public:\n')
        for field in fields:
            write('  ' + self._cpp_ptr_type(field) + ' ' + field['title'] + ';\n')
        if fields:
            write('\n')
        write('  ' + struct_name + '(Domain *domain, VarID parent, PlatoVariableSyncType sync_type);\n')
        write('  virtual ~' + struct_name + '();\n')
        write('  auto static New(Domain *domain, VarID parent, PlatoVariableSyncType sync_type) -> std::shared_ptr<' + struct_name + '>;\n')
        write('  virtual auto serialize(PlatoStream &stream) -> bool override;\n')
        write('  virtual auto deserialize(PlatoStream &stream) -> bool override;\n')
        write('  virtual auto copy(Variable *other) -> void override;\n')
        write('  virtual auto object_size() -> std::size_t override;\n')
        write('  virtual auto copy_default() -> void override;\n')
        write('  virtual auto complete_prototype() -> void override;\n')
        write('  static auto get_fixed_mem_size() -> std::size_t;\n')
        write('};\n\n')
        write('}\n\n')

    def gen_struct_cpp(self, struct):
        """Write ``<out_dir>/<title>.cpp`` implementing the struct class.

        Args:
            struct: dict with "title" and a "fields" list; each field has
                "title" and "type" (plus "key"/"value" for containers).
        """
        struct_name = struct["title"]
        with open(os.path.join(self.out_dir, struct_name) + '.cpp', "w") as file:
            self._emit_struct_cpp(file, struct)

    def _emit_struct_cpp(self, file, struct):
        """Write the complete implementation text of one struct to ``file``."""
        struct_name = struct["title"]
        fields = struct["fields"]
        write = file.write

        write('#include "' + struct_name + '.hh"\n')
        # One include per distinct non-basic field type, in first-seen order.
        included = set()
        for field in fields:
            field_type = field['type']
            if not self.is_basic_type(field_type) and field_type not in included:
                write('#include "' + field_type + '.hh"\n')
                included.add(field_type)
        write("\n")
        write('namespace plato {\n\n')

        # Constructor with one member initialiser per field.
        write(struct_name + "::" + struct_name + '(Domain *domain, VarID parent, PlatoVariableSyncType sync_type)\n')
        write('  : StructVariable(domain, parent, sync_type)')
        if fields:
            write(',\n')
        # Every initialiser, container or not, must close both the New<>()
        # call and the member-init parenthesis.
        initialisers = [
            '    ' + field['title'] + '(domain->New<' + self._cpp_type(field) + '>(sync_type))'
            for field in fields
        ]
        write(',\n'.join(initialisers))
        write(' {\n  }\n\n')
        write(struct_name + "::~" + struct_name + '() {}\n\n')

        write('auto ' + struct_name + "::New" + '(Domain *domain, VarID parent, PlatoVariableSyncType sync_type) -> std::shared_ptr<' + struct_name + '> {\n')
        write('  return plato::make_shared<' + struct_name + '>(domain->mem_block(), domain, parent, sync_type);\n')
        write('}\n\n')

        write('auto ' + struct_name + '::serialize(PlatoStream &stream) -> bool {\n')
        write('  stream << id();\n')
        for field in fields:
            write('  ' + field['title'] + '->serialize(stream);\n')
        write('  return true;\n')
        write('}\n\n')

        write('auto ' + struct_name + '::deserialize(PlatoStream &stream) -> bool {\n')
        write('  stream.skip(sizeof(VarID));\n')
        for field in fields:
            write('  ' + field['title'] + '->deserialize(stream);\n')
        write('  return true;\n')
        write('}\n\n')

        write('auto ' + struct_name + '::copy(Variable *other) -> void {\n')
        write('  auto *ptr = dynamic_cast<' + struct_name + ' *>(other);\n')
        write('  if (!ptr) { return; }\n')
        for field in fields:
            write('  ' + field['title'] + '->copy(ptr->' + field['title'] + '.get());\n')
        write('}\n\n')

        write('auto ' + struct_name + '::object_size() -> std::size_t {\n')
        write('  return sizeof(' + struct_name + ');\n')
        write('}\n\n')

        write('auto ' + struct_name + '::copy_default() -> void {\n')
        for field in fields:
            write('  ' + field['title'] + '->copy_default();\n')
        write('}\n\n')

        write('auto ' + struct_name + '::complete_prototype() -> void { copy_default(); }\n\n')

        write('auto ' + struct_name + '::get_fixed_mem_size() -> std::size_t {\n')
        write('  std::size_t total_size = sizeof(' + struct_name + ');\n')
        for field in fields:
            write('  total_size += ' + self._cpp_size_expr(field) + ';\n')
        write('  return total_size;\n')
        write('}\n\n')
        write('}\n\n')

if __name__ == "__main__":
    # Usage: -f/--file <json description>  -o/--out <output directory>
    # Both default to "."; note that the working directory has already been
    # switched to this script's directory by the time they are used.
    try:
        # Short options take a value each ("f:" and "o:"); the option string
        # must not contain the leading dashes.
        opts, args = getopt.getopt(sys.argv[1:], 'f:o:', ['file=', 'out='])
    except getopt.GetoptError as e:
        # Report a bad command line as a message rather than a traceback.
        sys.stderr.write(str(e) + "\n")
        sys.exit(2)
    file_path = "."
    out_dir = "."
    for opt, value in opts:
        if opt in ("-o", "--out"):
            out_dir = value
        elif opt in ("-f", "--file"):
            file_path = value
    # isfile() rather than exists(): the default "." exists but is a
    # directory, which cannot be opened as the JSON description.
    if os.path.isfile(file_path):
        NodeGenerator(file_path, out_dir)
    else:
        sys.stderr.write("description file not found: " + file_path + "\n")
